import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from functools import reduce

# define function to create PCA plots


def myplot(score, coeff, labels=None):
    """Draw a PCA biplot on the current matplotlib axes.

    The first two columns of ``score`` are scattered after each one is
    divided by its own range, and every row of ``coeff`` is drawn as a red
    arrow from the origin with a green text label placed at 1.15 times the
    arrow tip.

    Args:
        score: 2-D array of projected observations; only columns 0 and 1
            are used.
        coeff: 2-D array with one row per original variable; columns 0 and
            1 give the end point of that variable's arrow.
        labels: Optional sequence holding one label per row of ``coeff``.
            When ``None``, the module-level ``emplbat`` list is indexed
            instead, so the fallback only makes sense when the rows of
            ``coeff`` line up with ``emplbat``.
    """
    xs = score[:, 0]
    ys = score[:, 1]

    # Divide each component by its range so the scaled points span a unit
    # interval; a constant component is left unscaled rather than divided
    # by zero.
    x_range = xs.max() - xs.min()
    y_range = ys.max() - ys.min()
    scalex = 1.0 / x_range if x_range else 1.0
    scaley = 1.0 / y_range if y_range else 1.0
    plt.scatter(xs * scalex, ys * scaley, s=2)
    # No legend is requested here: the scatter carries no colour-mapped
    # values, so ``legend_elements()`` would only return empty handles.

    for i in range(coeff.shape[0]):
        tip_x = coeff[i, 0]
        tip_y = coeff[i, 1]
        plt.arrow(0, 0, tip_x, tip_y, color="r", alpha=0.5)
        caption = emplbat[i] if labels is None else labels[i]
        plt.text(
            tip_x * 1.15,
            tip_y * 1.15,
            caption,
            color="g",
            ha="center",
            va="center",
        )


# set up data

# list employment items


# Employment items, grouped by battery.
compensation_benefits = ["Q30", "Q31", "Q32"]
physical_environment = ["Q33"]
job_security = ["Q37", "Q34"]
autonomy_engagement_fairness = ["Q35", "Q39", "Q46", "Q40", "Q36"]
training_opportunities = ["Q38", "Q44", "Q45", "Q47"]
work_times = ["Q41"]

# Flat list of all employment items; the battery order chosen here is the
# order used wherever ``emplbat`` indexes rows or columns.
emplbat = [
    *compensation_benefits,
    *training_opportunities,
    *autonomy_engagement_fairness,
    *job_security,
    *work_times,
    *physical_environment,
]

# list attitude items

antisystem = ["Q{}".format(number) for number in range(102, 109)]
behavior = ["Q{}".format(number) for number in range(110, 115)]
populism = ["Q{}".format(number) for number in range(116, 125)]

politems = [*antisystem, *behavior, *populism]

# covariates

covs_rename = [
    "gender_q", "race", "education_q",
    "n_employees", "industry_q", "position_q",
    "working_time", "contract_type_q", "years_in_company",
    "party_close", "party_voted",
    "ResponseId", "firm_attention", "informal",
]

survey = pd.read_csv("data/qualtrics/brazil/brazil_clean.csv")

# Reverse-code the antisystem items: 1 <-> 5 and 2 <-> 4 (3 stays 3).
reverse_coded = survey[antisystem].replace(range(1, 6), range(5, 0, -1))
survey[antisystem] = reverse_coded

# Keep only the analysis columns and drop every respondent with a missing
# value in any of them.
analysis_columns = emplbat + politems + covs_rename
survey = survey[analysis_columns].dropna()

# attitudes and employment PCA

# One-component PCA yields the per-battery index; the two-component PCA is
# only used to draw the per-battery biplots.
reducer = PCA(n_components=1)

reducer_two = PCA(n_components=2)

batt_names = [
    "antisystem",
    "behavior",
    "populism",
    "compensation_benefits",
    "physical_environment",
    "job_security",
    "autonomy_engagement_fairness",
    "training_opportunities",
    "work_times",
]

batteries = [
    antisystem,
    behavior,
    populism,
    compensation_benefits,
    physical_environment,
    job_security,
    autonomy_engagement_fairness,
    training_opportunities,
    work_times,
]


battdict = dict(zip(batt_names, batteries))

reduced_dfs = []

for batt, qids in battdict.items():
    # Standardise the items of this battery before the PCA.
    survey_qids = StandardScaler().fit_transform(survey[qids])

    qids_reduced = reducer.fit_transform(survey_qids)

    # A biplot needs two components, so single-item batteries get no plot.
    if len(qids) >= 2:
        qids_reduced_two = reducer_two.fit_transform(survey_qids)

        plt.figure(figsize=(12, 8))
        plt.xlim(-1, 1)
        plt.ylim(-1, 1)
        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.grid()

        myplot(qids_reduced_two, np.transpose(reducer_two.components_), labels=qids)
        plt.savefig("plots/qualtrics_br/pca_" + batt + ".eps", format="eps")
        # Release the figure once it has been written; pyplot otherwise
        # keeps every figure open until the script ends.
        plt.close()

    # Standardise the first component again and store it under the battery
    # name.
    qids_reduced = StandardScaler().fit_transform(qids_reduced)
    qids_reduced = pd.DataFrame(qids_reduced, columns=[batt])

    reduced_dfs.append(qids_reduced)

# All per-battery frames share the same 0..n-1 index, so one column-wise
# concat is enough.
df_reduced = pd.concat(reduced_dfs, axis=1)

# Two-component PCA over the pooled employment items and, separately, over
# the pooled attitude items.
reducer = PCA(n_components=2)

survey_empl = StandardScaler().fit_transform(survey[emplbat])

empl_reduced = reducer.fit_transform(survey_empl)

survey_pol = StandardScaler().fit_transform(survey[politems])

pol_reduced = reducer.fit_transform(survey_pol)

pol_reduced = pd.DataFrame(pol_reduced)
empl_reduced = pd.DataFrame(empl_reduced)

reduced = pd.concat([pol_reduced, empl_reduced], axis=1)

reduced = pd.DataFrame(
    StandardScaler().fit_transform(reduced),
    columns=["pol_1", "pol_2", "empl_1", "empl_2"],
)

# reset_index() gives the survey a fresh 0..n-1 index (dropna() may have
# left gaps) so its rows line up with the PCA frames; the previous index is
# kept as an ``index`` column.
survey = pd.concat([survey.reset_index(), df_reduced], axis=1)

# NOTE(review): this second reset_index() adds a ``level_0`` column that only
# repeats 0..n-1. It is kept so the layout of the saved CSV does not change;
# confirm whether any consumer needs it before removing.
survey = pd.concat([survey.reset_index(), reduced], axis=1)

# make scales consistent by flipping the sign of selected PCA scores
# NOTE(review): which scores need flipping presumably depends on the signs
# the PCA returned for the current data -- re-check after any data change.

# higher values indicate more anti systemness
for column in ("antisystem", "populism"):
    survey[column] = -survey[column]

# higher values indicate better employment conditions
for column in (
    "physical_environment",
    "compensation_benefits",
    "work_times",
    "job_security",
    "autonomy_engagement_fairness",
    "training_opportunities",
):
    survey[column] = -survey[column]

# first component of the pooled attitude PCA
survey["pol_1"] = -survey["pol_1"]

# save pca reduced survey
survey.to_csv("data/qualtrics/brazil_pca.csv")

# correlation plot

# Centre every item on its mean within each firm_attention group, then
# correlate the centred items with each other.
item_columns = emplbat + politems
centered_items = survey.groupby(["firm_attention"])[item_columns].apply(
    lambda group: group - group.mean()
)
corrs = centered_items.corr()

# Column order for the heatmap: attitudes, populism, then behaviours.
politems = antisystem + populism + behavior

# Keep only the employment-by-attitude rectangle of the correlation matrix.
corrs = corrs.loc[emplbat, politems]

# Column tick labels for the heatmap.
# NOTE(review): there are 15 labels here, while ``politems`` holds 21 items
# (7 antisystem + 9 populism + 5 behavior) -- confirm the labels match the
# columns.
pollabels = [
    # anti-system attitudes
    "Democracy",
    "Best gov. system",
    "Rights protected",
    "Well represented",
    "Resp. institutions",
    "Elections",
    # populist attitudes
    "Elites - People",
    "Common citizen",
    "Polits. talk much",
    # anti-system behaviors
    "Immigration",
    "Social media",
    "Taxes",
    "Protests",
    "Overthrow gov.",
    "AS Parties",
]


def draw_brace(ax, xspan, yy, text, axis="x"):
    """Plot a captioned curly brace on ``ax`` without clipping it.

    Args:
        ax: Axes-like object; its x and y limits are read, autoscaling is
            switched off, and the brace and caption are added to it.
        xspan: ``(start, end)`` of the brace along the chosen axis.
        yy: Base offset of the brace on the other axis; the curve is
            plotted at the negated offset.
        text: Caption placed at the middle of the brace.
        axis: ``"x"`` for a brace running along x, ``"y"`` for one running
            along y. Any other value draws nothing.
    """
    start, end = xspan
    view_left, view_right = ax.get_xlim()
    view_width = view_right - view_left
    view_bottom, view_top = ax.get_ylim()
    view_height = view_top - view_bottom

    # Odd number of samples, so the curve has exactly one centre point.
    n_samples = int((end - start) / view_width * 100) * 2 + 1
    # Larger values give tighter curls at the ends and the tip.
    steepness = 300.0 / view_width

    along = np.linspace(start, end, n_samples)
    left_part = along[: int(n_samples / 2) + 1]

    # Half of the brace is the sum of two logistic curves: one rising at
    # the outer end, one rising at the centre.
    outer_curl = 1 / (1.0 + np.exp(-steepness * (left_part - left_part[0])))
    centre_curl = 1 / (1.0 + np.exp(-steepness * (left_part - left_part[-1])))
    half_profile = outer_curl + centre_curl

    # Mirror the half (without repeating the centre sample) for the full
    # brace, then scale and shift it relative to the y-extent of the view.
    profile = np.concatenate((half_profile, half_profile[-2::-1]))
    across = yy + (0.05 * profile - 0.01) * view_height

    midpoint = (end + start) / 2.0
    caption_offset = -yy - 0.12 * view_height

    ax.autoscale(False)
    if axis == "x":
        ax.plot(along, -across, color="black", lw=1, clip_on=False)
        ax.text(
            midpoint,
            caption_offset,
            text,
            fontsize=14,
            ha="center",
            va="bottom",
        )
    elif axis == "y":
        ax.plot(-across, along, color="black", lw=1, clip_on=False)
        ax.text(
            caption_offset,
            midpoint,
            text,
            fontsize=14,
            ha="center",
            va="center",
        )


# Row tick labels for the heatmap.
# NOTE(review): there are 17 labels here, while ``emplbat`` holds 16 items
# (``job_security`` lists two items but three labels follow the
# "Recommend to others" entry) -- confirm the labels match the rows.
empllabels = [
    "Fair pay",
    "Pay + Benefits",
    "Pay matches performance",
    "Training satisfaction",
    "Opportunities for development",
    "Can achieve career objectives",
    "Good career opportunities",
    "Involvement in decisions",
    "Self-realization",
    "Values contributions",
    "Equal opportunities",
    "Recommend to others",
    "Stability - No risk of layoff",
    "Successful firm",
    "Declining firm (rec)",
    "Working times - Work/Life balance",
    "Physical environment",
]

fig, ax = plt.subplots(figsize=(16, 12))
# Draw on the axes created above instead of relying on pyplot's notion of
# the current axes.
sns.heatmap(
    corrs, annot=True, cbar=False, fmt=".2f", cmap="vlag", center=0, ax=ax
)
ax.xaxis.set_ticks_position("top")

# Braces grouping the heatmap columns.
draw_brace(ax, (0, 5.95), -17, "Anti-System \n Attitudes")
draw_brace(ax, (6.05, 8.95), -17, "Populist \n Attitudes")
draw_brace(ax, (9.05, 15), -17, "Anti-System \n Behaviors")

# Braces grouping the heatmap rows.
draw_brace(ax, (0, 2.95), -14.8, "Pay & Benefits", axis="y")
draw_brace(ax, (3.05, 6.95), -14.8, "Opportunities \n & Training", axis="y")
draw_brace(ax, (7.05, 11.95), -14.8, "Engagement \n & Fairness", axis="y")
draw_brace(ax, (12.05, 14.95), -14.8, "Job Security", axis="y")

# One tick at the centre of each cell, labelled with the readable names.
plt.xticks(ticks=np.arange(0.5, len(pollabels) + 0.5, 1), labels=pollabels, rotation=45)
plt.yticks(
    ticks=np.arange(0.5, len(empllabels) + 0.5, 1), labels=empllabels, rotation=360
)
plt.savefig(
    "plots/qualtrics_br/correlation_plot.eps", format="eps", bbox_inches="tight"
)


